#pragma once

#include "dtype.hpp"
#include "exception.hpp"
#include "fmt/format.h"
#include "global_variable.hpp"
#include "shmobj.hpp"
#include "spdlog/common.h"
#include "spdlog/fmt/bundled/core.h"
#include "spdlog/fmt/bundled/format.h"
#include "spdlog/logger.h"
#include "spdlog/sinks/basic_file_sink.h"
#include "spdlog/sinks/dup_filter_sink.h"
#include "spdlog/spdlog.h"
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <pthread.h>
#include <pybind11/detail/common.h>
#include <pybind11/numpy.h>
#include <pybind11/pytypes.h>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

class Handle {
private:
  static std::hash<std::string> hasher;

  struct VarMeta {
    std::string id;
    uint32_t expect_index;
    uint32_t local_version;
    PyDataType dtype;
    size_t nbytes;
    bool is_buf_proto;
    std::optional<std::unique_ptr<ShmObj>> shmobj;
  };

  struct VarBinary {
    char name[64];
    char id[64];
    bool is_buf_proto;
    std::atomic_size_t nbytes;
    std::atomic_uint32_t ref_count;
    std::atomic_uint32_t version;
    PyDataType dtype;
    pthread_mutex_t mutex;

    void lock() { pthread_mutex_lock(&mutex); }

    void unlock() { pthread_mutex_unlock(&mutex); }
  };

  struct HandleBinary {
    std::uint32_t capacity;    // variable list capacity
    std::atomic_uint32_t size; // variable list size
    pthread_mutex_t mutex;
    // var_meta array
  };

  std::string _name;
  std::unique_ptr<ShmObj> _shm;
  HandleBinary *_mmh;
  std::unordered_map<std::string, VarMeta> _var_map;
  std::shared_ptr<spdlog::logger> logger;

  VarBinary *varlist() {
    return static_cast<VarBinary *>(static_cast<void *>(
        static_cast<char *>(static_cast<void *>(this->_mmh)) +
        sizeof(uint32_t) + sizeof(std::atomic_uint32_t) +
        sizeof(pthread_mutex_t)));
  }

  bool update(std::string_view name, VarMeta *meta, bool attach = true) {
    auto [varbin, idx] = this->find_varbin(name, meta);
    if (varbin == nullptr) {
      return false;
    }
    std::lock_guard<VarBinary> g(*varbin);
    if (varbin->id != meta->id) {
      meta->dtype = varbin->dtype;
      meta->id = varbin->id;
      meta->expect_index = idx;
      meta->is_buf_proto = varbin->is_buf_proto;
      meta->local_version = varbin->version;
      meta->nbytes = varbin->nbytes;
      if (attach)
        meta->shmobj = std::make_unique<ShmObj>(meta->id);
      return true;
    }
    // if varbin is newer
    if (meta->local_version < varbin->version) {
      meta->local_version = varbin->version;
      meta->dtype = varbin->dtype;
      meta->is_buf_proto = varbin->is_buf_proto;
      meta->nbytes = varbin->nbytes;
      meta->shmobj.reset();
    }
    if (!meta->shmobj.has_value() && attach) {
      meta->shmobj = std::make_unique<ShmObj>(meta->id);
    }
    return true;
  }

  std::pair<VarBinary *, uint32_t> find_varbin(std::string_view name,
                                               std::optional<VarMeta *> meta) {
    if (meta.has_value()) {
      auto varbin = varlist() + meta.value()->expect_index;
      if (meta.has_value() && varbin->id == meta.value()->id) {
        return {varbin, meta.value()->expect_index};
      }
      logger->warn("varbin index({}) change. expect id({}), but found({}). "
                   "fallback to scan varlist",
                   meta.value()->expect_index, meta.value()->id, varbin->id);
    }
    for (uint32_t i = 0; i < size(); i++) {
      auto tmpbin = varlist() + i;
      if (tmpbin->name == name) {
        logger->debug("varbin({}) found at index({})", name, i);
        return {tmpbin, i};
      }
    }
    logger->error("varbin({}) not found", name);
    return {nullptr, 0};
  }

  py::object parse(VarMeta const *meta) {
    if (meta->dtype == PyDataType::PyBool) {
      return py::bool_(*static_cast<bool *>(meta->shmobj.value()->data()));
    } else if (meta->dtype == PyDataType::PyInteger) {
      return py::int_(*static_cast<int64_t *>(meta->shmobj.value()->data()));
    } else if (meta->dtype == PyDataType::PyFloat) {
      return py::float_(*static_cast<double *>(meta->shmobj.value()->data()));
    } else if (meta->dtype == PyDataType::PyStr) {
      void *buf = meta->shmobj.value()->data();
      size_t str_length = *static_cast<size_t *>(buf);
      char *str_ptr = static_cast<char *>(buf) + sizeof(size_t);
      return py::str(std::string_view{str_ptr, str_length});
    } else if (meta->dtype == PyDataType::NdArray) {
      auto buff = static_cast<BuffProto *>(meta->shmobj.value()->data());
      std::vector<ssize_t> shape(buff->shape(), buff->shape() + buff->ndims);
      std::vector<ssize_t> strides(buff->strides(),
                                   buff->strides() + buff->ndims);
      return py::array(py::memoryview::from_buffer(
          buff->ptr(), buff->itemsize, buff->format, shape, strides));
    } else {
      return py::none();
    }
  }

  void init_logger() {
    auto sink = std::make_shared<spdlog::sinks::dup_filter_sink_mt>(
        std::chrono::seconds(5));
    sink->add_sink(
        std::make_shared<spdlog::sinks::basic_file_sink_mt>(LOGFILE));
    this->logger =
        std::make_shared<spdlog::logger>(std::string(this->name()), sink);
    this->logger->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%n] %P [%l] | %v");
    this->logger->set_level(LOG_LEVEL);
    this->logger->flush_on(spdlog::level::warn);
    spdlog::register_logger(this->logger);
    using namespace std::chrono_literals;
    spdlog::flush_every(3s);
  }

public:
  Handle(std::string_view name, uint32_t &&capacity) : _name(name) {
    this->init_logger();
    uint64_t reqbytes = sizeof(HandleBinary) + sizeof(VarBinary) * capacity;
    logger->debug("creating shmpy.Handle({}). req bytes: {}", name, reqbytes);
    try {
      this->_shm = std::make_unique<ShmObj>(name, nullptr, reqbytes);
    } catch (std::exception const &e) {
      throw HandleError(
          fmt::format("fail to create shared memory object. {}", e.what()));
    }
    logger->debug("ShmObj create success");
    this->_mmh = static_cast<HandleBinary *>(this->_shm->data());
    this->_mmh->capacity = capacity;
    this->_mmh->size = 0;
    pthread_mutex_init(&_mmh->mutex, &PTHREAD_MUTEXATTR);
    logger->debug("shmpy.Handle create success");
  }

  Handle(std::string_view name) : _name(name) {
    this->init_logger();
    logger->debug("attaching to existing shmpy.Handle({})", name);
    try {
      this->_shm = std::make_unique<ShmObj>(name);
    } catch (std::exception const &e) {
      throw HandleError(
          fmt::format("fail to attach to shared memory object. {}", e.what()));
    }
    logger->debug("attaching to ShmObj success");
    this->_mmh = static_cast<HandleBinary *>(this->_shm->data());
    logger->debug("creating lock_guard for shmpy.Handle");
    std::lock_guard<Handle> g(*this);
    logger->debug("shmpy.Handle lock_guard created");
    logger->debug("loading existing variable meta, count: {}", _mmh->size);
    if (_mmh->size > 0) {
      for (uint32_t i = 0; i < _mmh->size; i++) {
        auto varbin = this->varlist() + i;
        this->_var_map.emplace(
            varbin->name, VarMeta{varbin->id, i, varbin->version, varbin->dtype,
                                  varbin->nbytes, varbin->is_buf_proto});
      }
    }
    logger->debug("exsting variable meta load success");
  }

  ~Handle() {
    logger->debug("disposing shmpy.Handle({})", name());
    spdlog::drop(this->logger->name());
    this->logger->flush();
  }

  void insert(std::string &&name, const py::object &obj) {
    if (this->_var_map.contains(name)) {
      throw py::key_error(fmt::format("variable \"{}\" already exist", name));
    }
    if (name.length() >= 64 - _name.length()) {
      throw VariableError(fmt::format("variable name too long. expect < {}",
                                      64 - _name.length()));
    }
    std::lock_guard<Handle> hG(*this);
    // check capacity
    if (this->size() >= this->capacity()) {
      throw py::index_error("max capacity");
    }

    // check if variable name unique
    for (uint32_t i = 0; i < size(); i++) {
      auto varbin = varlist() + i;
      if (varbin->name == name) {
        throw py::key_error(fmt::format("variable \"{}\" already exist", name));
      }
    }

    // init VarBinary for new insert one
    auto varbin = this->varlist() + size();
    // varbin lock guard
    std::lock_guard<VarBinary> g(*varbin);
    strncpy(varbin->name, name.c_str(), 64);

    int64_t now_ns = std::chrono::time_point_cast<std::chrono::nanoseconds>(
                         std::chrono::system_clock::now())
                         .time_since_epoch()
                         .count();
    fmt::format_to_n(varbin->id, 64, "{}-{:X}", this->name(), now_ns);

    varbin->nbytes = get_pyobj_size(obj);
    varbin->dtype = get_pyobj_dtype(obj);
    varbin->version = 0;
    varbin->is_buf_proto = varbin->dtype == PyDataType::NdArray;

    // create shmobj for the variable
    auto varshm = std::make_unique<ShmObj>(varbin->id, nullptr, varbin->nbytes);
    get_pyobj_data(obj, varbin->dtype, varshm->data(), varbin->nbytes);
    this->_var_map.emplace(
        name, VarMeta{varbin->id, size(), varbin->version, varbin->dtype,
                      varbin->nbytes, varbin->is_buf_proto, std::move(varshm)});
    this->_mmh->size++;
  }

  void set(std::string_view name, const py::object &obj) {}

  void del(std::string &&name) {
    auto pair = this->_var_map.find(name);
    if (pair == this->_var_map.end()) {
      return;
    }
    auto varmeta = &pair->second;
    { std::lock_guard<Handle> g(*this); }
  }

  py::object get(std::string &&name) {
    std::lock_guard<Handle> g(*this);
    // try find it locally
    auto iter = this->_var_map.find(name);
    // local found
    if (iter != _var_map.end()) {
      if (!this->update(iter->first, &iter->second)) {
        this->_var_map.erase(iter->first);
        return py::none();
      }
      auto meta = &iter->second;
      return this->parse(meta);
    }
    // local not found
    for (uint32_t i = 0; i < size(); i++) {
      auto varbin = varlist() + i;
      if (varbin->name == name) {
        auto [k, v] = this->_var_map.emplace(
            name, VarMeta{varbin->id, i, 0, varbin->dtype, varbin->nbytes,
                          varbin->is_buf_proto,
                          std::make_unique<ShmObj>(varbin->id)});
        if (v)
          return this->parse(&k->second);
        return py::none();
      }
    }
    return py::none();
  }

  uint32_t ref_count() const noexcept { return this->_shm->ref_count(); }

  uint32_t capacity() const noexcept { return this->_mmh->capacity; }

  uint32_t size() const noexcept { return this->_mmh->size; }

  std::string_view name() const noexcept { return this->_name; }

  spdlog::level::level_enum get_loglevel() { return this->logger->level(); }

  void set_loglevel(spdlog::level::level_enum level) {
    this->logger->set_level(level);
  }

  void lock() { pthread_mutex_lock(&this->_mmh->mutex); }
  void unlock() { pthread_mutex_unlock(&this->_mmh->mutex); }
};